{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "FaissEvalTopKColabToGPT2DatasetDownLoadEdition.ipynb",
      "version": "0.3.2",
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "TPU"
  },
  "cells": [
    {
      "metadata": {
        "id": "JcrrSVNi8S5-",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "#May have a memory crash in CPU or GPU mode. If so, run in TPU mode. \n",
        "\n",
        "#The way I set it up was mounting it from my drive. If you wish to do the same you'll need the FFNNEmbed CSV files found here\n",
        "#  https://drive.google.com/open?id=1ee6I9OHrCiN-wv5BtZX75nyOiWuWTnfs\n",
        "# and put them in your drive in a folder named 'FFNNEmbeds'\n",
        "\n",
        "#run all cells. The cell at the bottom creates a CSV file"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "odozicZsVrMZ",
        "colab_type": "code",
        "outputId": "d6798eeb-9d86-468c-feb8-0aee74024d7f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 121
        }
      },
      "cell_type": "code",
      "source": [
        "import pandas as pd\n",
        "import numpy as np\n",
        "import os\n",
        "import csv\n",
        "from tqdm import tqdm\n",
        "\n",
        "!pip install -U -q PyDrive\n",
        "\n",
        "from google.colab import files\n",
        "from pydrive.auth import GoogleAuth\n",
        "from pydrive.drive import GoogleDrive\n",
        "from google.colab import auth\n",
        "from oauth2client.client import GoogleCredentials\n",
        "\n",
        "auth.authenticate_user()\n",
        "gauth = GoogleAuth()\n",
        "gauth.credentials = GoogleCredentials.get_application_default()\n",
        "drive = GoogleDrive(gauth)\n",
        "# drive.mount('/content/gdrive')\n",
        "# files = os.listdir(\"/content/gdrive/My Drive/FFNNEmbeds\")"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code\n",
            "\n",
            "Enter your authorization code:\n",
            "··········\n",
            "Mounted at /content/gdrive\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "mmrJD2VIN6fw",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def download_file_from_google_drive(id, destination):\n",
        "    URL = \"https://docs.google.com/uc?export=download\"\n",
        "\n",
        "    session = requests.Session()\n",
        "\n",
        "    response = session.get(URL, params = { 'id' : id }, stream = True)\n",
        "    token = get_confirm_token(response)\n",
        "\n",
        "    if token:\n",
        "        params = { 'id' : id, 'confirm' : token }\n",
        "        response = session.get(URL, params = params, stream = True)\n",
        "\n",
        "    save_response_content(response, destination)    \n",
        "\n",
        "def get_confirm_token(response):\n",
        "    for key, value in response.cookies.items():\n",
        "        if key.startswith('download_warning'):\n",
        "            return value\n",
        "\n",
        "    return None\n",
        "\n",
        "def save_response_content(response, destination):\n",
        "    CHUNK_SIZE = 32768\n",
        "\n",
        "    with open(destination, \"wb\") as f:\n",
        "        for chunk in response.iter_content(CHUNK_SIZE):\n",
        "            if chunk: # filter out keep-alive new chunks\n",
        "                f.write(chunk)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Pi272ybDN98B",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "file_id = '1hbk2z_zwTwKPZ3yKyVjjvSaj5rbnmBmj'\n",
        "\n",
        "download_file_from_google_drive(file_id, 'eHealthForums.csv')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "E9VU-v7mOKb_",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "file_id = '1rUU6wrEW72dbC3D7yJB1G5QdUg9UImS6'\n",
        "\n",
        "download_file_from_google_drive(file_id, 'healthTapFFNNEmbeddings.csv')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "vhPPNdQgOLk3",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "file_id = '1vsLTKh6A7Tzpyf_qHHYgYG4wzX71ThoF'\n",
        "\n",
        "download_file_from_google_drive(file_id, 'askDocsFFNNEmbeddings.csv')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "yRUiqpLoOLTg",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "file_id = '1y6LzAZDhNwMtnz6QFUwMWKENzayvzeUF'\n",
        "\n",
        "download_file_from_google_drive(file_id, 'webMDFFNNEmbeddings.csv')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ONRtJ9kxOK9p",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "file_id = '15x44Pglx3bg46Bqp2HaWR1OW3LFOURXr'\n",
        "\n",
        "download_file_from_google_drive(file_id, 'iClinicFFNNEmbeddings.csv')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "lvesd_SRN95p",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "files = os.listdir()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ftB3GRCJo_dh",
        "colab_type": "code",
        "outputId": "074d756c-5e23-4fcb-b137-5a7d200cc193",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        }
      },
      "cell_type": "code",
      "source": [
        "files"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "['iClinicFFNNEmbeddings.csv',\n",
              " 'webMDFFNNEmbeddings.csv',\n",
              " 'askDocsFFNNEmbeddings.csv',\n",
              " 'healthTapFFNNEmbeddings.csv',\n",
              " 'eHealthForums.csv']"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "metadata": {
        "id": "JW0C5N-UWtwT",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def fix_array(x):\n",
        "    x = np.fromstring(\n",
        "    x.replace('\\n','')\n",
        "    .replace('[','')\n",
        "    .replace(']','')\n",
        "    .replace('  ',' '), sep=' ')\n",
        "    return x.reshape((1, 768))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "niyPY7RMz40j",
        "colab_type": "code",
        "outputId": "b69f279a-f224-455f-bb32-46856c2f95b5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "cell_type": "code",
      "source": [
        "pd.options.display.max_rows = 700\n",
        "pd.set_option('expand_frame_repr', True)\n",
        "pd.set_option('max_colwidth', 250)\n",
        "pd.get_option(\"display.max_rows\")"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "700"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "metadata": {
        "id": "WqlnLnWnWvqi",
        "colab_type": "code",
        "outputId": "825c29e6-8193-49b5-c16c-e7d9622cd1af",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84
        }
      },
      "cell_type": "code",
      "source": [
        "qa = pd.read_csv(\"/content/gdrive/My Drive/FFNNEmbeds/\" + files[0])\n",
        "for file in files[1:]:\n",
        "    print(file)\n",
        "    qa = pd.concat([qa, pd.read_csv(\"/content/gdrive/My Drive/FFNNEmbeds/\" + file)], axis = 0)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "webMDFFNNEmbeddings.csv\n",
            "askDocsFFNNEmbeddings.csv\n",
            "healthTapFFNNEmbeddings.csv\n",
            "eHealthForums.csv\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "wGh5zlJs1eOE",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "del qa['Unnamed: 0']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "hsJKRYwCWyzS",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "qa[\"Q_FFNN_embeds\"] = qa[\"Q_FFNN_embeds\"].apply(fix_array)\n",
        "qa[\"A_FFNN_embeds\"] = qa[\"A_FFNN_embeds\"].apply(fix_array)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "JoSZ5ImRW1-T",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# qa = qa.sample(frac = 1)\n",
        "# qa.reset_index(inplace = True, drop = True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "DteXdlBh2jBl",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "qa = qa.reset_index(drop=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "erQwDeii1Jns",
        "colab_type": "code",
        "outputId": "a7ce2210-f896-4991-dfca-fe19cd7ad2ba",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 615
        }
      },
      "cell_type": "code",
      "source": [
        "qa.head()"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>question</th>\n",
              "      <th>answer</th>\n",
              "      <th>Q_FFNN_embeds</th>\n",
              "      <th>A_FFNN_embeds</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>is it fine to exercise with knee pain?</td>\n",
              "      <td>from your description it appears that you may have anterior knee pain which sometimes presents as pain at the back of the knee. the second possibility is that you have over done your exercise and hamstrings are sore and the lower end of the knee ...</td>\n",
              "      <td>[[-0.0236410219, 0.128144488, 0.30177781, 0.430800796, 0.0534969121, 1.24597561, 0.908191383, -0.0682020858, 1.49799263, 0.0182549823, 0.159414783, -0.000211991704, 0.000532273552, -0.0226245634, -0.0949951187, -0.0198594704, 1.25414896, 0.360615...</td>\n",
              "      <td>[[-0.115900427, 0.182273567, -0.0992501155, -0.0385564938, 0.740532875, 0.167836308, 0.0228974223, -0.23636511, -0.172942773, 0.25693363, -0.24061957, -0.0263501219, 0.0366276056, -0.126324296, -0.00830298476, 0.158679247, 0.0368386209, -0.393908...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>suffering from anxiety restlessness and taking clonazepam &amp; mirtazapine. i had depression &amp; panic attacks for the last 7 years. need a second opinion.</td>\n",
              "      <td>depression anxiety restlessness and panic attacks are best respond to a combination of a selective serotonin reuptake inhibitors (ssris) and benzodiazepines. mirtazapine is not the most efficient ant depressant and it is not a true ssri as fluoxe...</td>\n",
              "      <td>[[-0.114040159, 0.140852138, 0.0999100953, 0.208959475, 0.445060819, -0.403621316, 0.0552569143, -0.0890201628, 0.380254924, 0.108175397, 0.191386163, 0.198272184, -0.133397803, 0.212966055, 0.352456003, -0.07531479, 0.341809332, 0.151971966, 0.2...</td>\n",
              "      <td>[[-0.155445561, 0.0941457227, 0.0411367081, 0.255660385, 0.30229032, -0.253930658, -0.176501781, 0.917308688, -0.183687523, 0.209650502, -0.0855835676, 0.239034563, 0.762762547, 0.259996265, 0.34940359, -0.0778813362, -0.315282226, -0.212972909, ...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>can a thyroid patient eat soybean and fenugreek?</td>\n",
              "      <td>patients with hypothyroidism usually gain weight. your issue is weight loss which is seen in hyperthyroidism. as your thyroid tsh is in the normal range you can continue the same treatment. for gaining weight you need to take food which is adequa...</td>\n",
              "      <td>[[-0.0945982337, 0.04440815, -0.219836012, 0.196798429, 0.402243942, -0.0414064191, 1.45660985, 0.214540154, 0.410930932, -0.0685333759, 1.27774835, 0.206304967, -0.320117086, 0.0697160363, 0.302823365, -0.181685016, 1.5156101, 0.679543853, -0.08...</td>\n",
              "      <td>[[-0.267755091, 0.496627122, -0.10593427, 0.103044294, 0.482136637, 0.210582703, -0.194033623, 0.417719126, -0.306303591, 0.23062554, -0.251927078, -0.0518328063, 0.22350128, 0.335046589, -0.0103447549, 0.497870088, 0.153194278, -0.197104171, 0.4...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>i am not getting my periods after taking fenugreek seeds. why?</td>\n",
              "      <td>fenugreek seed cannot affect your fertility do not worry. a delayed period could be because of stress or hormonal imbalance so please get a urine pregnancy test to rule out pregnancy first. for further queries consult an infertility specialist on...</td>\n",
              "      <td>[[-0.0772762671, -0.0546492599, -0.290007412, 0.234665141, -0.0717474371, 0.210957736, 0.501454949, -0.33825165, 0.535232365, 0.141104475, 0.105892427, 0.167244852, -0.226345539, 0.0772434399, 0.36760062, -0.121957906, 0.0198328495, 1.51193035, 0...</td>\n",
              "      <td>[[-0.116682649, 0.229011983, 0.0717253983, 0.0508261435, 0.188567623, 0.305067122, -0.131368637, -0.294420958, -0.202981532, 0.0689166486, -0.0496700145, 0.0921775177, 0.522677302, 0.173455626, 0.19158718, -0.0330378525, -0.165551513, -0.17966866...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>kindly suggest me a therapy to overcome heat allergy.</td>\n",
              "      <td>your problem is a characteristic of cholinergic urticaria. it is a type of urticaria where patients are allergic to their own sweat. so whenever a patient sweats for example due to heat sun stress exercise. etc. the patient develops hives. levoce...</td>\n",
              "      <td>[[0.190717205, -0.109118752, 0.225994617, 0.448829651, -0.368060827, 1.09241807, 0.22018972, 0.1560826, -0.269190162, 0.220193535, 1.38247335, 0.249987021, -0.202662915, -0.0893628523, 0.185544237, -0.270314127, 0.238500059, -0.357885599, 0.36902...</td>\n",
              "      <td>[[-0.0469765291, 0.159166887, -0.205679953, -0.120664336, 0.337732494, -0.0699605793, -0.248624474, 0.400170565, -0.30279389, 0.209822923, -0.101327896, -0.0589761622, 0.456502318, 0.149086922, -0.0414450914, -0.176266044, -0.028876394, -0.345402...</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                                                                                                                                                 question  \\\n",
              "0                                                                                                                  is it fine to exercise with knee pain?   \n",
              "1  suffering from anxiety restlessness and taking clonazepam & mirtazapine. i had depression & panic attacks for the last 7 years. need a second opinion.   \n",
              "2                                                                                                        can a thyroid patient eat soybean and fenugreek?   \n",
              "3                                                                                          i am not getting my periods after taking fenugreek seeds. why?   \n",
              "4                                                                                                   kindly suggest me a therapy to overcome heat allergy.   \n",
              "\n",
              "                                                                                                                                                                                                                                                      answer  \\\n",
              "0  from your description it appears that you may have anterior knee pain which sometimes presents as pain at the back of the knee. the second possibility is that you have over done your exercise and hamstrings are sore and the lower end of the knee ...   \n",
              "1  depression anxiety restlessness and panic attacks are best respond to a combination of a selective serotonin reuptake inhibitors (ssris) and benzodiazepines. mirtazapine is not the most efficient ant depressant and it is not a true ssri as fluoxe...   \n",
              "2  patients with hypothyroidism usually gain weight. your issue is weight loss which is seen in hyperthyroidism. as your thyroid tsh is in the normal range you can continue the same treatment. for gaining weight you need to take food which is adequa...   \n",
              "3  fenugreek seed cannot affect your fertility do not worry. a delayed period could be because of stress or hormonal imbalance so please get a urine pregnancy test to rule out pregnancy first. for further queries consult an infertility specialist on...   \n",
              "4  your problem is a characteristic of cholinergic urticaria. it is a type of urticaria where patients are allergic to their own sweat. so whenever a patient sweats for example due to heat sun stress exercise. etc. the patient develops hives. levoce...   \n",
              "\n",
              "                                                                                                                                                                                                                                               Q_FFNN_embeds  \\\n",
              "0  [[-0.0236410219, 0.128144488, 0.30177781, 0.430800796, 0.0534969121, 1.24597561, 0.908191383, -0.0682020858, 1.49799263, 0.0182549823, 0.159414783, -0.000211991704, 0.000532273552, -0.0226245634, -0.0949951187, -0.0198594704, 1.25414896, 0.360615...   \n",
              "1  [[-0.114040159, 0.140852138, 0.0999100953, 0.208959475, 0.445060819, -0.403621316, 0.0552569143, -0.0890201628, 0.380254924, 0.108175397, 0.191386163, 0.198272184, -0.133397803, 0.212966055, 0.352456003, -0.07531479, 0.341809332, 0.151971966, 0.2...   \n",
              "2  [[-0.0945982337, 0.04440815, -0.219836012, 0.196798429, 0.402243942, -0.0414064191, 1.45660985, 0.214540154, 0.410930932, -0.0685333759, 1.27774835, 0.206304967, -0.320117086, 0.0697160363, 0.302823365, -0.181685016, 1.5156101, 0.679543853, -0.08...   \n",
              "3  [[-0.0772762671, -0.0546492599, -0.290007412, 0.234665141, -0.0717474371, 0.210957736, 0.501454949, -0.33825165, 0.535232365, 0.141104475, 0.105892427, 0.167244852, -0.226345539, 0.0772434399, 0.36760062, -0.121957906, 0.0198328495, 1.51193035, 0...   \n",
              "4  [[0.190717205, -0.109118752, 0.225994617, 0.448829651, -0.368060827, 1.09241807, 0.22018972, 0.1560826, -0.269190162, 0.220193535, 1.38247335, 0.249987021, -0.202662915, -0.0893628523, 0.185544237, -0.270314127, 0.238500059, -0.357885599, 0.36902...   \n",
              "\n",
              "                                                                                                                                                                                                                                               A_FFNN_embeds  \n",
              "0  [[-0.115900427, 0.182273567, -0.0992501155, -0.0385564938, 0.740532875, 0.167836308, 0.0228974223, -0.23636511, -0.172942773, 0.25693363, -0.24061957, -0.0263501219, 0.0366276056, -0.126324296, -0.00830298476, 0.158679247, 0.0368386209, -0.393908...  \n",
              "1  [[-0.155445561, 0.0941457227, 0.0411367081, 0.255660385, 0.30229032, -0.253930658, -0.176501781, 0.917308688, -0.183687523, 0.209650502, -0.0855835676, 0.239034563, 0.762762547, 0.259996265, 0.34940359, -0.0778813362, -0.315282226, -0.212972909, ...  \n",
              "2  [[-0.267755091, 0.496627122, -0.10593427, 0.103044294, 0.482136637, 0.210582703, -0.194033623, 0.417719126, -0.306303591, 0.23062554, -0.251927078, -0.0518328063, 0.22350128, 0.335046589, -0.0103447549, 0.497870088, 0.153194278, -0.197104171, 0.4...  \n",
              "3  [[-0.116682649, 0.229011983, 0.0717253983, 0.0508261435, 0.188567623, 0.305067122, -0.131368637, -0.294420958, -0.202981532, 0.0689166486, -0.0496700145, 0.0921775177, 0.522677302, 0.173455626, 0.19158718, -0.0330378525, -0.165551513, -0.17966866...  \n",
              "4  [[-0.0469765291, 0.159166887, -0.205679953, -0.120664336, 0.337732494, -0.0699605793, -0.248624474, 0.400170565, -0.30279389, 0.209822923, -0.101327896, -0.0589761622, 0.456502318, 0.149086922, -0.0414450914, -0.176266044, -0.028876394, -0.345402...  "
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "metadata": {
        "id": "vnjuTC7LW4Jv",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "question_bert = np.concatenate(qa[\"Q_FFNN_embeds\"].values, axis=0)\n",
        "answer_bert = np.concatenate(qa[\"A_FFNN_embeds\"].values, axis=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ODTXaWyDW6SK",
        "colab_type": "code",
        "outputId": "6c2c73d4-19d0-475d-84ec-d39255b1cc2b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "cell_type": "code",
      "source": [
        "answer_bert.shape\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(239402, 768)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 14
        }
      ]
    },
    {
      "metadata": {
        "id": "8_ank7uIW8Nb",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "question_bert = question_bert.astype('float32')\n",
        "answer_bert = answer_bert.astype('float32')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "Q1DnMFKlx6As",
        "colab_type": "code",
        "outputId": "a93652fa-3ec8-4a6a-c08e-98b994ecda4b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1045
        }
      },
      "cell_type": "code",
      "source": [
        "#just the CPU version of FAISS, will have to look deeper on how to get GPU version, but works fast enough for now \n",
        "\n",
        "!wget  https://anaconda.org/pytorch/faiss-cpu/1.2.1/download/linux-64/faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2\n",
        "!tar xvjf faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2\n",
        "!cp -r lib/python3.6/site-packages/* /usr/local/lib/python3.6/dist-packages/\n",
        "!pip install mkl"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "--2019-04-29 19:17:20--  https://anaconda.org/pytorch/faiss-cpu/1.2.1/download/linux-64/faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2\n",
            "Resolving anaconda.org (anaconda.org)... 104.17.93.24, 104.17.92.24, 2606:4700::6811:5d18, ...\n",
            "Connecting to anaconda.org (anaconda.org)|104.17.93.24|:443... connected.\n",
            "HTTP request sent, awaiting response... 302 FOUND\n",
            "Location: https://binstar-cio-packages-prod.s3.amazonaws.com/5a15c9c5c376961204909d87/5aa7f0a65571b411e5c259be?response-content-disposition=attachment%3B%20filename%3D%22faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2%22%3B%20filename%2A%3DUTF-8%27%27faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2&response-content-type=application%2Fx-tar&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=60&X-Amz-Date=20190429T191721Z&X-Amz-SignedHeaders=host&X-Amz-Security-Token=AgoJb3JpZ2luX2VjEHgaCXVzLWVhc3QtMSJHMEUCIQD8dyA2%2B7TsLOyNCsf2w4LA%2BX5MPQ54wemDyZ4JM6KCdwIgG4oUJOOJzQD27dd0XtTAvnp2MqzIQybMlF%2FNrIL%2BXpkq2gMIYRAAGgw0NTU4NjQwOTgzNzgiDCP8vFYJJ4ZE3Iupjiq3A60V0WA8hpQb7tZ7LT3RGC5UqYiMf%2FtwImmWXlY6MdjhUM5ejlG3etCQw12PmA%2Fud21GUzfpI3G2bbCjpECPZiO7kWPYVJcibO0V%2BAcpBWXsX9gL%2FXavuvw1Rf7LKLrTDKkn5eKjgRJSCuI7lGJciNcFpGvgsqfnc1wXusL2ei9d83TTpH3tAWjTW9R9mW%2BGq0QcD5XXqKymZyFQUA3VEK2FDkFE%2FBwT3HFnBDrQqmTM3fTWjMZo0tTTn7qyom7AdLoT%2BVw7EVk6a4myrNh2o9ATEFgurtuRtAxHJsgwpFiDHUUZb1yG3%2BDYfxKQLyMRpkziGi60Ch69FWvkkjcyl9Mrt6pc4ZJJDDCzx6BIGxDZGZOv7LGtrxyQeWPaa5Oy4TxX%2BE54SUK9u%2Blf%2BDkAhTs5wiiGRc8zj2qtbDZF5rZaVK22z4yuSyg8ypRzxKzLz%2FkuEYg9WUPP5hBDcg%2Fcsxu9gRYohht8fcs1X%2FYzGXw6VH7lHIV7VcbQ2zVKI3libAbZlYe8%2FhDz8Kz7oeobDwHq9VI3%2BX5W9KDvwmSU2WeBU%2FVQyQRzjHGuPfVUv6%2BI4PYK5tCYxnYwu7ac5gU6tAFpnjXgpcTEZJjA3d2NJceKzZUqddtQNPCnuifIQeJV5%2BeiymH%2FT7%2FUO38x%2FWCoO0N%2FxEsg5LBL9gqa9OQ6r%2BiUZATGJs%2BFy9incLWnR9e01A0RiF%2Bw3u22zCYLq7hv67ViN0zitbB%2B0a%2FZPBgVb14Yhp8p325SxoY5EE1wPF9N85PTV1v%2FFX5kooeGS1zgjNh9qAPzNYJj1Q%2FRTpwqvlNhrgNej%2FP4nqRVBvhM2rJPxsFRJFw%3D&X-Amz-Credential=ASIAWUI46DZFHDZVNVOG%2F20190429%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=6c8dc77826eb9a6d9c7970c90116716e1162a4069d209ddf2bcffc9444aa7796 [following]\n",
            "--2019-04-29 19:17:21--  https://binstar-cio-packages-prod.s3.amazonaws.com/5a15c9c5c376961204909d87/5aa7f0a65571b411e5c259be?response-content-disposition=attachment%3B%20filename%3D%22faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2%22%3B%20filename%2A%3DUTF-8%27%27faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2&response-content-type=application%2Fx-tar&X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Expires=60&X-Amz-Date=20190429T191721Z&X-Amz-SignedHeaders=host&X-Amz-Security-Token=AgoJb3JpZ2luX2VjEHgaCXVzLWVhc3QtMSJHMEUCIQD8dyA2%2B7TsLOyNCsf2w4LA%2BX5MPQ54wemDyZ4JM6KCdwIgG4oUJOOJzQD27dd0XtTAvnp2MqzIQybMlF%2FNrIL%2BXpkq2gMIYRAAGgw0NTU4NjQwOTgzNzgiDCP8vFYJJ4ZE3Iupjiq3A60V0WA8hpQb7tZ7LT3RGC5UqYiMf%2FtwImmWXlY6MdjhUM5ejlG3etCQw12PmA%2Fud21GUzfpI3G2bbCjpECPZiO7kWPYVJcibO0V%2BAcpBWXsX9gL%2FXavuvw1Rf7LKLrTDKkn5eKjgRJSCuI7lGJciNcFpGvgsqfnc1wXusL2ei9d83TTpH3tAWjTW9R9mW%2BGq0QcD5XXqKymZyFQUA3VEK2FDkFE%2FBwT3HFnBDrQqmTM3fTWjMZo0tTTn7qyom7AdLoT%2BVw7EVk6a4myrNh2o9ATEFgurtuRtAxHJsgwpFiDHUUZb1yG3%2BDYfxKQLyMRpkziGi60Ch69FWvkkjcyl9Mrt6pc4ZJJDDCzx6BIGxDZGZOv7LGtrxyQeWPaa5Oy4TxX%2BE54SUK9u%2Blf%2BDkAhTs5wiiGRc8zj2qtbDZF5rZaVK22z4yuSyg8ypRzxKzLz%2FkuEYg9WUPP5hBDcg%2Fcsxu9gRYohht8fcs1X%2FYzGXw6VH7lHIV7VcbQ2zVKI3libAbZlYe8%2FhDz8Kz7oeobDwHq9VI3%2BX5W9KDvwmSU2WeBU%2FVQyQRzjHGuPfVUv6%2BI4PYK5tCYxnYwu7ac5gU6tAFpnjXgpcTEZJjA3d2NJceKzZUqddtQNPCnuifIQeJV5%2BeiymH%2FT7%2FUO38x%2FWCoO0N%2FxEsg5LBL9gqa9OQ6r%2BiUZATGJs%2BFy9incLWnR9e01A0RiF%2Bw3u22zCYLq7hv67ViN0zitbB%2B0a%2FZPBgVb14Yhp8p325SxoY5EE1wPF9N85PTV1v%2FFX5kooeGS1zgjNh9qAPzNYJj1Q%2FRTpwqvlNhrgNej%2FP4nqRVBvhM2rJPxsFRJFw%3D&X-Amz-Credential=ASIAWUI46DZFHDZVNVOG%2F20190429%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Signature=6c8dc77826eb9a6d9c7970c90116716e1162a4069d209ddf2bcffc9444aa7796\n",
            "Resolving binstar-cio-packages-prod.s3.amazonaws.com (binstar-cio-packages-prod.s3.amazonaws.com)... 52.216.145.171\n",
            "Connecting to binstar-cio-packages-prod.s3.amazonaws.com (binstar-cio-packages-prod.s3.amazonaws.com)|52.216.145.171|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 4106816 (3.9M) [application/x-tar]\n",
            "Saving to: ‘faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2’\n",
            "\n",
            "faiss-cpu-1.2.1-py3 100%[===================>]   3.92M  17.5MB/s    in 0.2s    \n",
            "\n",
            "2019-04-29 19:17:21 (17.5 MB/s) - ‘faiss-cpu-1.2.1-py36_cuda9.0.176_1.tar.bz2’ saved [4106816/4106816]\n",
            "\n",
            "info/hash_input.json\n",
            "info/has_prefix\n",
            "info/index.json\n",
            "info/git\n",
            "info/files\n",
            "info/LICENSE.txt\n",
            "info/about.json\n",
            "info/paths.json\n",
            "lib/python3.6/site-packages/faiss-0.1-py3.6.egg-info/dependency_links.txt\n",
            "lib/python3.6/site-packages/faiss-0.1-py3.6.egg-info/not-zip-safe\n",
            "lib/python3.6/site-packages/faiss-0.1-py3.6.egg-info/requires.txt\n",
            "lib/python3.6/site-packages/faiss-0.1-py3.6.egg-info/top_level.txt\n",
            "lib/python3.6/site-packages/faiss-0.1-py3.6.egg-info/native_libs.txt\n",
            "info/test/run_test.py\n",
            "info/test/run_test.sh\n",
            "info/test/tests/run_tests.sh\n",
            "lib/python3.6/site-packages/faiss-0.1-py3.6.egg-info/SOURCES.txt\n",
            "info/recipe/conda_build_config.yaml\n",
            "info/recipe/build.sh\n",
            "info/test/tests/CMakeLists.txt\n",
            "info/test/tests/Makefile\n",
            "info/recipe/meta.yaml.template\n",
            "lib/python3.6/site-packages/faiss-0.1-py3.6.egg-info/PKG-INFO\n",
            "info/test/tests/test_factory.py\n",
            "info/test/tests/test_ivfpq_codec.cpp\n",
            "info/recipe/meta.yaml\n",
            "info/recipe/setup.py\n",
            "info/test/tests/test_blas.cpp\n",
            "info/recipe/makefile.inc\n",
            "info/test/tests/test_ivfpq_indexing.cpp\n",
            "info/test/tests/test_ondisk_ivf.cpp\n",
            "info/test/tests/test_build_blocks.py\n",
            "info/test/tests/test_merge.cpp\n",
            "info/test/tests/test_pairs_decoding.cpp\n",
            "info/test/tests/test_index_composite.py\n",
            "lib/python3.6/site-packages/faiss/__init__.py\n",
            "lib/python3.6/site-packages/faiss/__pycache__/__init__.cpython-36.pyc\n",
            "info/test/tests/test_index.py\n",
            "info/test/tests/test_blas\n",
            "lib/python3.6/site-packages/faiss/__pycache__/swigfaiss.cpython-36.pyc\n",
            "lib/python3.6/site-packages/faiss/swigfaiss.py\n",
            "lib/python3.6/site-packages/faiss/_swigfaiss.so\n",
            "Requirement already satisfied: mkl in /usr/local/lib/python3.6/dist-packages (2019.0)\n",
            "Requirement already satisfied: intel-openmp in /usr/local/lib/python3.6/dist-packages (from mkl) (2019.0)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "metadata": {
        "id": "oMnnEA8tyA5r",
        "colab_type": "code",
        "outputId": "741dd429-430c-425d-b8fa-a6d28684e742",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 50
        }
      },
      "cell_type": "code",
      "source": [
        "import faiss"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Failed to load GPU Faiss: No module named 'faiss.swigfaiss_gpu'\n",
            "Faiss falling back to CPU-only.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "metadata": {
        "id": "AO6Tp_IuW-YC",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "\n",
        "# faiss only accepts C-contiguous float32 matrices. Normalise the embeddings\n",
        "# once here; np.ascontiguousarray does not copy when the array already has\n",
        "# that dtype and layout.\n",
        "answer_vectors = np.ascontiguousarray(answer_bert, dtype='float32')\n",
        "question_vectors = np.ascontiguousarray(question_bert, dtype='float32')\n",
        "\n",
        "# Flat inner-product index over the answer embeddings.\n",
        "answer_index = faiss.IndexFlatIP(answer_vectors.shape[-1])\n",
        "answer_index.add(answer_vectors)\n",
        "\n",
        "# Flat inner-product index over the question embeddings.\n",
        "question_index = faiss.IndexFlatIP(question_vectors.shape[-1])\n",
        "question_index.add(question_vectors)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "VISMzNtxvQRy",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "def topKforGPT2(xf, kf):\n",
        "    \"\"\"Search both faiss indexes with a single question embedding.\n",
        "\n",
        "    Args:\n",
        "        xf: Row position of the query inside `question_bert`.\n",
        "        kf: Number of neighbours requested from each index.\n",
        "\n",
        "    Returns:\n",
        "        A tuple `(I1, I2)`: the id arrays returned by `answer_index.search`\n",
        "        and `question_index.search` for that one query row. The distance\n",
        "        arrays are discarded.\n",
        "    \"\"\"\n",
        "    # Slice instead of indexing so the query keeps its leading row axis, and\n",
        "    # cast once rather than once per search.\n",
        "    query = question_bert[xf:xf+1].astype('float32')\n",
        "    _, I1 = answer_index.search(query, kf)\n",
        "    _, I2 = question_index.search(query, kf)\n",
        "    return I1, I2"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "8G5XqKpc30uR",
        "colab_type": "code",
        "outputId": "78b02c1d-b0f7-4e26-a387-c35b626873cb",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "cell_type": "code",
      "source": [
        "\n",
        "# Retrieved (question, answer) pairs per row: half come from the answer\n",
        "# index and half from the question index.\n",
        "numberSamples=10\n",
        "neighbors_per_index = numberSamples // 2\n",
        "\n",
        "# Header: the source pair, then one question/answer column pair per retrieved\n",
        "# neighbour. Sized from what each row actually holds, so an odd numberSamples\n",
        "# cannot yield a header wider than the data rows.\n",
        "firstrow = ['question', 'answer']\n",
        "for ii in range(2 * neighbors_per_index):\n",
        "    firstrow.append('question'+str(ii))\n",
        "    firstrow.append('answer'+str(ii))\n",
        "\n",
        "# newline='' is what the csv module asks for (the writer emits its own line\n",
        "# endings), and the with-block closes the file even if a row fails.\n",
        "with open('GPT2data.csv', 'w', newline='', encoding='utf-8') as output:\n",
        "    writer = csv.writer(output)\n",
        "    writer.writerow(firstrow)\n",
        "\n",
        "    # For a quick smoke test, iterate over range(1000) instead of every row.\n",
        "    for k in tqdm(range(qa.shape[0]), mininterval=30, maxinterval=60):\n",
        "        rowfill = [qa['question'][k], qa['answer'][k]]\n",
        "        aa, qq = topKforGPT2(k, neighbors_per_index)\n",
        "        # Answer-index hits first, then question-index hits, each flattened\n",
        "        # to question, answer, question, answer, ...\n",
        "        for hits in (aa, qq):\n",
        "            matched = qa2.loc[list(hits[0]), :]\n",
        "            for pair in zip(matched['question'], matched['answer']):\n",
        "                rowfill.extend(pair)\n",
        "        writer.writerow(rowfill)\n",
        "\n",
        "# Re-create the Drive client once, after the export, so `drive` carries fresh\n",
        "# credentials. Doing this inside the loop repeated the whole auth handshake\n",
        "# for every row.\n",
        "auth.authenticate_user()\n",
        "gauth = GoogleAuth()\n",
        "gauth.credentials = GoogleCredentials.get_application_default()\n",
        "drive = GoogleDrive(gauth)"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 1000/1000 [03:59<00:00,  4.18it/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "metadata": {
        "id": "3IifYduc4Lp6",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Send GPT2data.csv to Google Drive through the `drive` client.\n",
        "# CreateFile() is called without metadata, same as before.\n",
        "gpt2_csv_upload = drive.CreateFile()\n",
        "gpt2_csv_upload.SetContentFile('GPT2data.csv')\n",
        "gpt2_csv_upload.Upload()"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}